{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Grover's algorithm\n",
    "\n",
    "In this notebook, we implement Grover's algorithm using Qiskit.\n",
    "\n",
    "We start by defining a function that, given a quantum register, builds the initial part of the circuit: it sets the last qubit (the one that receives the oracle output) to $|1\\rangle$ and then applies a Hadamard gate to every qubit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "from qiskit import *\n",
    "from qiskit.visualization import *\n",
    "from qiskit.tools.monitor import *\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def init_grover(q):\n",
    "\n",
    "    circ = QuantumCircuit(q) \n",
    "\n",
    "    n = len(q)\n",
    "    \n",
    "    circ.x(n-1) # The qubit that receives the oracle output must be set to |1>\n",
    "    for i in range(n):\n",
    "        circ.h(q[i])\n",
    "    \n",
    "    circ.barrier()\n",
    "    \n",
    "    return circ"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we define a function that builds the diffusion operator. It acts only on the input qubits of $q$, that is, on every qubit except the last one."
   ]
  },
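  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the operator that the function below is meant to realise can be written out explicitly. This is the standard textbook form, given as a sketch rather than taken from the code: $m = n-1$ (the number of input qubits) and $|s\\rangle$ (the uniform superposition over the $2^m$ inputs) are symbols introduced here for illustration.\n",
    "\n",
    "$$\n",
    "D = 2|s\\rangle\\langle s| - I = H^{\\otimes m}\\,\\bigl(2|0^m\\rangle\\langle 0^m| - I\\bigr)\\,H^{\\otimes m},\n",
    "\\qquad\n",
    "X^{\\otimes m}\\,\\bigl(C^{m-1}Z\\bigr)\\,X^{\\otimes m} = I - 2|0^m\\rangle\\langle 0^m|\n",
    "$$\n",
    "\n",
    "The gate sequence H, X, multicontrolled Z, X, H therefore implements $-D$, which differs from $D$ only by a global phase and so gives the same measurement probabilities."
   ]
  },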
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def difussion(q):\n",
    "    \n",
    "    circ = QuantumCircuit(q) \n",
    "    \n",
    "    # Diffusion operator on the n-1 input qubits (the last qubit is the oracle output)\n",
    "    \n",
    "    n = len(q)\n",
    "    \n",
    "    for i in range(n-1):\n",
    "        circ.h(q[i])\n",
    "        \n",
    "    for i in range(n-1):\n",
    "        circ.x(q[i])\n",
    "        \n",
    "    # To implement a multicontrolled Z we conjugate a multicontrolled X\n",
    "    # with Hadamard gates on its target qubit (H X H = Z)\n",
    "    \n",
    "    mcz = QuantumCircuit(q, name = 'cZ')\n",
    "    if(n>2):\n",
    "        mcz.h(q[n-2])\n",
    "        mcz.mcx(q[0:n-2],q[n-2])\n",
    "        mcz.h(q[n-2])\n",
    "    else:\n",
    "        mcz.z(q[0]) # If there is only one input qubit for the oracle, we don't have controls \n",
    "    \n",
    "    circ.append(mcz.to_instruction(),q)\n",
    "    \n",
    "    for i in range(n-1):\n",
    "        circ.x(q[i])\n",
    "\n",
    "    for i in range(n-1):\n",
    "        circ.h(q[i])\n",
    "\n",
    "    circ.barrier()\n",
    "    \n",
    "    return circ"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To test it, we define an oracle for the Boolean function that returns 1 only when every input bit is 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ones(q):\n",
    "\n",
    "    # We will use a multicontrolled X gate \n",
    "    \n",
    "    circ = QuantumCircuit(q)\n",
    "    \n",
    "    n = len(q)\n",
    "    \n",
    "    circ.mcx(q[0:n-1],q[n-1])\n",
    "    \n",
    "    return circ"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we define a function that builds the complete circuit from the total number of qubits $n$ (the $n-1$ oracle inputs plus one output qubit), the oracle, and the number of iterations. A final parameter selects whether measurements are added."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def grover(n, oracle, it = 10, measurement = True):\n",
    "    \n",
    "    q = QuantumRegister(n, name = 'q') # We create the quantum register\n",
    "    if(measurement):\n",
    "        c = ClassicalRegister(n-1,name='c') # We are only going to measure the qubits that are the input to the oracle\n",
    "        circ = QuantumCircuit(q,c) # We create the circuit\n",
    "    else:\n",
    "        circ = QuantumCircuit(q) # Circuit without measurements\n",
    "    \n",
    "    # compose(..., inplace=True) appends a sub-circuit on the same qubits\n",
    "    circ.compose(init_grover(q), inplace=True) # We add the initial part\n",
    "    \n",
    "    for _ in range(it): # We add it repetitions of the oracle plus the diffusion operator\n",
    "        circ.compose(oracle(q), inplace=True)\n",
    "        circ.compose(difussion(q), inplace=True)\n",
    "        \n",
    "    if(measurement):  # Measurements\n",
    "        circ.measure(q[0:n-1],c)\n",
    "            \n",
    "        \n",
    "    return circ"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We check the behaviour of the algorithm starting with a simple case: with one marked element out of four (two input qubits, so $n = 3$ counting the oracle's output qubit), a single iteration finds it with probability 1, so every shot should return `11`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 3\n",
    "\n",
    "circ_grover = grover(n,ones,1)\n",
    "circ_grover.draw(output = 'mpl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "backend = Aer.get_backend('qasm_simulator')\n",
    "job = execute(circ_grover, backend)\n",
    "counts = job.result().get_counts()\n",
    "print(counts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we are going to see the evolution of the probability of finding a marked element when the number of iterations changes. We will consider four input qubits and a number of iterations from 0 to 20."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "n = 5\n",
    "max_it = 20\n",
    "shots = 1000\n",
    "backend = Aer.get_backend('qasm_simulator')\n",
    "target=(n-1)*'1' # The marked element as a string, to retrieve its probability\n",
    "\n",
    "prob = [0.0 for _ in range(max_it+1)]\n",
    "\n",
    "for it in range(max_it+1):\n",
    "    circ_grover2 = grover(n,ones,it)\n",
    "    job = execute(circ_grover2, backend, shots = shots)\n",
    "    counts = job.result().get_counts()\n",
    "    prob[it] = counts.get(target, 0)/shots # 0 if the marked element was never measured\n",
    "\n",
    "iterations = range(max_it+1) # Named iterations to avoid shadowing the built-in iter\n",
    "plt.xlabel('Iterations')\n",
    "plt.ylabel('Probability')\n",
    "plt.plot(iterations,prob)\n",
    "plt.show()"
   ]
  },
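  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The measured curve can be compared with the ideal, noise-free success probability, which has a standard closed form. The sketch below assumes a single marked element among $N = 2^{n-1}$ inputs; the symbols $N$, $k$ (number of iterations), $\\theta$ and $P_k$ are introduced here for illustration and do not appear in the code.\n",
    "\n",
    "$$\n",
    "\\sin\\theta = \\frac{1}{\\sqrt{N}}, \\qquad P_k = \\sin^2\\bigl((2k+1)\\,\\theta\\bigr)\n",
    "$$\n",
    "\n",
    "For $N = 16$ this gives $\\theta \\approx 0.2527$ and $P_0 = 0.0625$, $P_1 \\approx 0.473$, $P_2 \\approx 0.908$, $P_3 \\approx 0.961$. The maximum sits where $(2k+1)\\theta \\approx \\pi/2$, that is, $k \\approx \\frac{\\pi}{4}\\sqrt{N}$ for large $N$."
   ]
  },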
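  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, after 3 iterations, the integer closest to $\\frac{\\pi}{4}\\sqrt{2^{n-1}} = \\pi \\approx 3.14$ for $n = 5$, the probability of finding the marked element is close to 1 (about 0.96 in the ideal case). With more iterations the probability drops and then rises again, so iterating past the optimum does not help."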
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
